{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 载入数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "#wget http://dl4img-1251985129.cosbj.myqcloud.com/dogcat.tar.gz\n",
    "#!tar -zxvf dogcat.tar.gz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import cv2\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "n = 25000\n",
    "width = 128\n",
    "\n",
    "X = np.zeros((n, width, width, 3), dtype=np.uint8)\n",
    "y = np.zeros((n,), dtype=np.uint8)\n",
    "\n",
    "for i in tqdm(range(n/2)):\n",
    "    X[i] = cv2.resize(cv2.imread('train/cat.%d.jpg' % i), (width, width))\n",
    "    X[i+n/2] = cv2.resize(cv2.imread('train/dog.%d.jpg' % i), (width, width))\n",
    "\n",
    "y[n/2:] = 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 分为训练集和验证集\n",
    "\n",
    "http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 数据集可视化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import random\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'retina'\n",
    "\n",
    "plt.figure(figsize=(12, 10))\n",
    "for i in range(12):\n",
    "    random_index = random.randint(0, n-1)\n",
    "    plt.subplot(3, 4, i+1)\n",
    "    plt.imshow(X[random_index][:,:,::-1])\n",
    "    plt.title(['cat', 'dog'][y[random_index]])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 搭建模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from keras.layers import *\n",
    "from keras.models import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "inputs = Input((width, width, 3))\n",
    "x = inputs\n",
    "for i, layer_num in enumerate([2, 2, 3, 3, 3]):\n",
    "    for j in range(layer_num):\n",
    "        x = Conv2D(32*2**i, 3, padding='same', activation='relu')(x)\n",
    "        x = BatchNormalization()(x)\n",
    "        x = Activation('relu')(x)\n",
    "    x = MaxPooling2D(2)(x)\n",
    "\n",
    "x = GlobalAveragePooling2D()(x)\n",
    "x = Dropout(0.5)(x)\n",
    "x = Dense(1, activation='sigmoid')(x)\n",
    "\n",
    "model = Model(inputs, x)\n",
    "model.compile(optimizer='adam',\n",
    "              loss='binary_crossentropy',\n",
    "              metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# from IPython.display import SVG\n",
    "# from keras.utils.vis_utils import model_to_dot\n",
    "\n",
    "# SVG(model_to_dot(model, show_shapes=True).create(prog='dot', format='svg'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "h = model.fit(X_train, y_train, batch_size=32, epochs=50, validation_data=(X_valid, y_valid))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10, 4))\n",
    "plt.subplot(1, 2, 1)\n",
    "plt.plot(h.history['loss'])\n",
    "plt.plot(h.history['val_loss'])\n",
    "plt.legend(['loss', 'val_loss'])\n",
    "plt.ylabel('loss')\n",
    "plt.xlabel('epoch')\n",
    "\n",
    "plt.subplot(1, 2, 2)\n",
    "plt.plot(h.history['acc'])\n",
    "plt.plot(h.history['val_acc'])\n",
    "plt.legend(['acc', 'val_acc'])\n",
    "plt.ylabel('acc')\n",
    "plt.xlabel('epoch')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
